import pybullet as p 
import pickle 
import pybullet_data
import pybullet_utils.bullet_client as bc
import os
import IPython


def check_body_collision(body1, body2, linkIndex1=-1, linkIndex2=-1,
                         bullet_client=None, distance=0.5):
    """Report the closest distance between two bodies and whether they collide.

    Queries ``getClosestPoints`` for every link in ``linkIndex1`` against
    ``linkIndex2`` of ``body2``. Only points closer than ``distance`` are
    returned by pybullet.

    Args:
        body1: Body unique id of the first body.
        body2: Body unique id of the second body.
        linkIndex1: A link index, or a list/tuple of link indices, on
            ``body1`` (-1 is pybullet's base-link index).
        linkIndex2: Link index on ``body2`` (-1 is the base link).
        bullet_client: Client to query. Defaults to the module-level ``_p``
            so existing callers keep working unchanged.
        distance: Query radius passed to ``getClosestPoints``. Also the value
            reported when no point is found.

    Returns:
        ``(min_distance, in_collision)``:
        - ``min_distance`` is the smallest contact distance over all returned
          points, capped at ``distance``.
        - ``in_collision`` is True iff any point has a negative contact
          distance (interpenetration).
    """
    client = _p if bullet_client is None else bullet_client
    links = linkIndex1 if isinstance(linkIndex1, (list, tuple)) else [linkIndex1]

    closest_pts = []
    for link in links:
        closest_pts.extend(client.getClosestPoints(bodyA=body1,
                                                   bodyB=body2,
                                                   linkIndexA=link,
                                                   linkIndexB=linkIndex2,
                                                   distance=distance))

    # Element 8 of a pybullet contact-point tuple is contactDistance.
    # Take the minimum over ALL points instead of returning at the first
    # penetrating one, so the reported distance is the true minimum.
    contact_distances = [pt[8] for pt in closest_pts]
    min_distance = min([distance] + contact_distances)
    in_collision = any(d < 0 for d in contact_distances)
    return min_distance, in_collision


def load_stl(_p, phys, path):
    """Load a ground plane and a mesh environment into a physics client.

    A ``plane.urdf`` is loaded at the origin; presumably it is resolved via a
    search path set by the caller. The mesh at ``path`` is then turned into a
    multi-body at z = -0.25 with identity orientation.

    Args:
        _p: pybullet client (a ``BulletClient`` or the ``pybullet`` module).
        phys: Physics client id, passed explicitly to every pybullet call.
        path: Filesystem path to the mesh file (e.g. an STL).

    Returns:
        The body unique id of the mesh environment.
    """
    # The plane is loaded only for its effect on the scene; its id is unused.
    _p.loadURDF(fileName='plane.urdf',
                basePosition=[0, 0, 0],
                baseOrientation=_p.getQuaternionFromEuler([0, 0, 0]),
                physicsClientId=phys)

    # Force a concave triangle mesh so the environment is not convex-hulled.
    col_shape_id = _p.createCollisionShape(
        shapeType=_p.GEOM_MESH,
        fileName=path,
        flags=_p.URDF_INITIALIZE_SAT_FEATURES | _p.GEOM_FORCE_CONCAVE_TRIMESH,
        physicsClientId=phys,
    )
    # The visual index must come from createVisualShape. Passing a collision
    # shape id as baseVisualShapeIndex is a type mismatch.
    viz_shape_id = _p.createVisualShape(
        shapeType=_p.GEOM_MESH,
        fileName=path,
        physicsClientId=phys,
    )

    body_id = _p.createMultiBody(
        baseVisualShapeIndex=viz_shape_id,
        baseCollisionShapeIndex=col_shape_id,
        basePosition=(0, 0, -0.25),
        baseOrientation=(0, 0, 0, 1),
        physicsClientId=phys,
    )
    return body_id

def load_robot(_p, phys_client, urdf_path):
    """Load a robot URDF at the world origin with identity orientation.

    Args:
        _p: pybullet client used to issue the ``loadURDF`` call.
        phys_client: Physics client id forwarded to ``loadURDF``.
        urdf_path: Path to the robot's URDF file.

    Returns:
        Whatever ``loadURDF`` returns (the body unique id of the robot).
    """
    origin = [0., 0., 0.]
    identity_quat = [0., 0., 0., 1.]
    robot_uid = _p.loadURDF(
        fileName=urdf_path,
        basePosition=origin,
        baseOrientation=identity_quat,
        physicsClientId=phys_client,
    )
    return robot_uid

def setPose(robot_id, pose, bullet_client, phys_client):
    """Teleport a robot to the planar pose ``(pose[0], pose[1], pose[2])``.

    The base is placed at ``(x, y, 0.0)``. Its orientation is built from Euler
    angles ``[0, 0, yaw]``, i.e. rotation about z only.

    Args:
        robot_id: Body unique id of the robot.
        pose: Indexable with at least 3 entries: x, y, yaw.
        bullet_client: pybullet client used for the conversion and reset.
        phys_client: Physics client id forwarded to the reset call.

    Returns:
        ``pose`` itself, unchanged.
    """
    x, y, yaw = pose[0], pose[1], pose[2]
    quaternion = bullet_client.getQuaternionFromEuler([0, 0, yaw])
    bullet_client.resetBasePositionAndOrientation(
        bodyUniqueId=robot_id,
        physicsClientId=phys_client,
        posObj=[x, y, 0.0],
        ornObj=quaternion,
    )
    return pose


# Interactive repair of start ("init") and goal poses for one experiment.
# Loads the poses from ./experiments/<robot>_<name>.pickle and checks each
# against the mesh ./envs/<name>.stl. For any pose in collision it prompts for
# a replacement, then writes the poses back to the same pickle file.

robot = "limo"
name = "6_small"

loc_path = "./experiments/" + robot + "_" + name + ".pickle"

env_path = "./envs/" + name + ".stl"

# Alternative robot models:
# robot_path = "./data/husky/husky.urdf"
# robot_path = "./data/simple_robot.urdf"
robot_path = "./data/limo/limo.urdf"

# Robot link index passed to check_body_collision.
default_link = 4

# Scale between stored x/y coordinates and simulation x/y coordinates (yaw is
# never scaled). It is applied when a pose is loaded and inverted when one is
# stored, so the saved pose is exactly the one that passed the collision check.
# Previously the load-side x3 was commented out while the store side still
# divided by 3, which saved a pose that was never checked.
xy_scale = 1.0

# Only the first `max_tests` start/goal pairs are inspected.
max_tests = 6

# NOTE: pickle.load can execute arbitrary code; only open trusted files.
with open(loc_path, "rb") as f:
    data = pickle.load(f)

init, goal = data["init"], data["goal"]

mode = p.DIRECT

_p = bc.BulletClient(connection_mode=mode)
phys = _p._client
_p.setAdditionalSearchPath(pybullet_data.getDataPath())
_p.setGravity(gravX=0, gravY=0, gravZ=-10, physicsClientId=phys)
env_id = load_stl(_p, phys, os.path.abspath(env_path))
robot_id = load_robot(_p, phys, robot_path)


def _to_sim(pose):
    """Map a stored [x, y, yaw, ...] pose to simulation coordinates."""
    return [pose[0] * xy_scale, pose[1] * xy_scale] + list(pose[2:])


def _to_stored(pose):
    """Inverse of _to_sim: map a simulation pose back to stored coordinates."""
    return [pose[0] / xy_scale, pose[1] / xy_scale] + list(pose[2:])


def _prompt_pose():
    """Prompt until the user enters a comma-separated pose with >= 3 floats."""
    while True:
        s = input("Enter new config: ")
        try:
            pose = [float(k) for k in s.split(",")]
        except ValueError:
            print("could not parse {!r}; expected x,y,yaw".format(s))
            continue
        if len(pose) < 3:
            print("expected at least 3 values (x,y,yaw), got {}".format(len(pose)))
            continue
        return pose


def _repair_pose(test_no, label, pose):
    """Return a repaired stored pose, or None if `pose` is collision-free.

    `pose` is in simulation coordinates. While the robot collides with the
    environment, the user is prompted for a new pose. The first
    collision-free entry is converted to stored coordinates and returned.
    """
    setPose(robot_id, pose, _p, phys)
    _, in_collision = check_body_collision(robot_id, env_id, default_link)
    if not in_collision:
        return None
    while True:
        print("test {} {} state in collision: {}".format(test_no, label, pose))
        new_s = _prompt_pose()
        setPose(robot_id, new_s, _p, phys)
        _, in_collision = check_body_collision(robot_id, env_id, default_link)
        if not in_collision:
            input("{} {} {} now okay, continue?".format(test_no, label, new_s))
            stored = _to_stored(new_s)
            print("storing {}".format(stored))
            return stored


for i, (start, end) in enumerate(zip(init, goal)):
    test_no = i + 1

    new_start = _repair_pose(test_no, "init", _to_sim(start))
    if new_start is None:
        print("test {} start {} okay".format(test_no, init[i]))
    else:
        init[i] = new_start

    new_end = _repair_pose(test_no, "goal", _to_sim(end))
    if new_end is None:
        print("test {} goal {} okay".format(test_no, goal[i]))
    else:
        goal[i] = new_end

    if test_no >= max_tests:
        break

new_data = {"init": init, "goal": goal}
with open(loc_path, "wb") as f:
    pickle.dump(new_data, f)

# Release the DIRECT-mode physics server explicitly.
_p.disconnect()
